{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Start with your own case \n",
    "\n",
    "In addition to the rich collcetion of **datasets**, **models** and **evaluation metrics**, **FederatedScope** also allows to create your own or introduce more to our package.\n",
    "\n",
    "We provide `register` function to help build your own federated learning workflow.  This introduction will help you to start with your own case:\n",
    "\n",
    "1. [Load a dataset](#data)\n",
    "2. [Build a model](#model) \n",
    "3. [Create a trainer](#trainer)\n",
    "4. [Introduce more evaluation metrics](#metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## <span id=\"data\">1. Load a dataset</span>\n",
    "\n",
    "We provide a function federatedscope.register.register_data to make your dataset available with three steps:\n",
    "\n",
    "* Step1: set up your data in the following format (standalone):\n",
    "    \n",
    "    **Note**: This function returns a `dict`, where the `key` is the client's id, and the `value` is the data `dict` of each client with 'train', 'test' or 'val'.  You can also modify the config here.\n",
    "\n",
    "    We take `torchvision.datasets.MNIST`, which is split and assigned to two clients, as an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T10:13:12.308497Z",
     "start_time": "2022-03-31T10:13:12.302160Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def load_my_data(config):\n",
    "    import numpy as np\n",
    "    from torchvision import transforms\n",
    "    from torchvision.datasets import MNIST\n",
    "    from torch.utils.data import DataLoader\n",
    "\n",
    "    # Build data\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.9637], std=[0.1592])\n",
    "    ])\n",
    "    data_train = MNIST(root=config.data.root, train=True, transform=transform, download=True)\n",
    "    data_test = MNIST(root=config.data.root, train=False, transform=transform, download=True)\n",
    "\n",
    "    # Split data into dict\n",
    "    data_dict = dict()\n",
    "    train_per_client = len(data_train) // config.federate.client_num\n",
    "    test_per_client = len(data_test) // config.federate.client_num\n",
    "\n",
    "    for client_idx in range(1, config.federate.client_num + 1):\n",
    "        dataloader_dict = {\n",
    "            'train':\n",
    "            DataLoader([\n",
    "                data_train[i]\n",
    "                for i in range((client_idx - 1) *\n",
    "                               train_per_client, client_idx * train_per_client)\n",
    "            ],\n",
    "                       config.data.batch_size,\n",
    "                       shuffle=config.data.shuffle),\n",
    "            'test':\n",
    "            DataLoader([\n",
    "                data_test[i]\n",
    "                for i in range((client_idx - 1) * test_per_client, client_idx *\n",
    "                               test_per_client)\n",
    "            ],\n",
    "                       config.data.batch_size,\n",
    "                       shuffle=False)\n",
    "        }\n",
    "        data_dict[client_idx] = dataloader_dict\n",
    "\n",
    "    return data_dict, config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "* Step2: register your data with a keyword, such as `\"mydata\"`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T10:13:12.313727Z",
     "start_time": "2022-03-31T10:13:12.309767Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from federatedscope.register import register_data\n",
    "\n",
    "def call_my_data(config, client_cfgs=None):\n",
    "    if config.data.type == \"mycvdata\":\n",
    "        data, modified_config = load_my_data(config)\n",
    "        return data, modified_config\n",
    "\n",
    "register_data(\"mycvdata\", call_my_data)"
   ]
  },
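  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index arithmetic in `load_my_data` (Step1) can be checked without downloading MNIST. Below is a minimal, torch-free sketch of the same contiguous split, assuming the standard MNIST sizes of 60000 training and 10000 test images. `shard_indices` is a hypothetical helper that is not part of FederatedScope, and plain `range` objects stand in for the `DataLoader`s.\n",
    "\n",
    "```python\n",
    "def shard_indices(num_samples, client_num):\n",
    "    # Return {client_id: range} with equal contiguous shards; ids start at 1.\n",
    "    if client_num <= 0:\n",
    "        raise ValueError('client_num must be positive')\n",
    "    per_client = num_samples // client_num\n",
    "    return {\n",
    "        client_idx: range((client_idx - 1) * per_client,\n",
    "                          client_idx * per_client)\n",
    "        for client_idx in range(1, client_num + 1)\n",
    "    }\n",
    "\n",
    "\n",
    "# Assumed dataset sizes: 60000 training and 10000 test images.\n",
    "train_shards = shard_indices(60000, client_num=5)\n",
    "test_shards = shard_indices(10000, client_num=5)\n",
    "\n",
    "# Same layout as the dict returned by load_my_data:\n",
    "# client id -> {'train': ..., 'test': ...}\n",
    "data_dict = {\n",
    "    cid: {'train': train_shards[cid], 'test': test_shards[cid]}\n",
    "    for cid in train_shards\n",
    "}\n",
    "print(sorted(data_dict))            # client ids\n",
    "print(len(data_dict[1]['train']))   # training samples per client\n",
    "```\n",
    "\n",
    "Because of the integer division, up to `client_num - 1` trailing samples stay unassigned whenever the dataset size is not divisible by the number of clients."
   ]
  },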
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T09:29:07.854271Z",
     "start_time": "2022-03-31T09:29:07.851771Z"
    },
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## <span id=\"model\">2. Build a model</span>\n",
    "We provide a function `federatedscope.register.register_model` to make your model available with three steps: (we take `ConvNet2` as an example)\n",
    "\n",
    "* Step1: build your model with Pytorch or Tensorflow and instantiate your model class with config and data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T10:13:12.787164Z",
     "start_time": "2022-03-31T10:13:12.315611Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class MyNet(torch.nn.Module):\n",
    "    def __init__(self,\n",
    "                 in_channels,\n",
    "                 h=32,\n",
    "                 w=32,\n",
    "                 hidden=2048,\n",
    "                 class_num=10,\n",
    "                 use_bn=True):\n",
    "        super(MyNet, self).__init__()\n",
    "        self.conv1 = torch.nn.Conv2d(in_channels, 32, 5, padding=2)\n",
    "        self.conv2 = torch.nn.Conv2d(32, 64, 5, padding=2)\n",
    "        self.fc1 = torch.nn.Linear((h // 2 // 2) * (w // 2 // 2) * 64, hidden)\n",
    "        self.fc2 = torch.nn.Linear(hidden, class_num)\n",
    "        self.relu = torch.nn.ReLU(inplace=True)\n",
    "        self.maxpool = torch.nn.MaxPool2d(2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.maxpool(self.relu(x))\n",
    "        x = self.conv2(x)\n",
    "        x = self.maxpool(self.relu(x))\n",
    "        x = torch.nn.Flatten()(x)\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "def load_my_net(model_config, data_shape):\n",
    "    # You can also build models without local_data\n",
    "    model = MyNet(in_channels=data_shape[1],\n",
    "                  h=data_shape[2],\n",
    "                  w=data_shape[3],\n",
    "                  hidden=model_config.hidden,\n",
    "                  class_num=model_config.out_channels)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "* Step2: register your model with a keyword, such as `\"mynet\"`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T10:13:12.791526Z",
     "start_time": "2022-03-31T10:13:12.788549Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from federatedscope.register import register_model\n",
    "\n",
    "def call_my_net(model_config, data_shape):\n",
    "    if model_config.type == \"mycnn\":\n",
    "        model = load_my_net(model_config, data_shape)\n",
    "        return model\n",
    "\n",
    "register_model(\"mycnn\", call_my_net)"
   ]
  },
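  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input size of `fc1` in `MyNet` follows from the layer shapes: both convolutions use a 5x5 kernel with `padding=2` and the default stride of 1, so they preserve height and width, while each `MaxPool2d(2)` halves them with floor division. The sketch below reproduces that arithmetic in plain Python. `flattened_features` is a hypothetical helper introduced only for illustration, and `data_shape` is assumed to be `(batch, channels, height, width)`, as the indexing in `load_my_net` implies.\n",
    "\n",
    "```python\n",
    "def flattened_features(h, w, out_channels=64, num_pools=2):\n",
    "    # Each 2x2 max-pool floors the spatial size to half;\n",
    "    # the padded 5x5 convolutions leave it unchanged.\n",
    "    for _ in range(num_pools):\n",
    "        h, w = h // 2, w // 2\n",
    "    return h * w * out_channels\n",
    "\n",
    "\n",
    "# MNIST images are 28x28: 28 -> 14 -> 7, so 7 * 7 * 64 features.\n",
    "print(flattened_features(28, 28))\n",
    "# The defaults h = w = 32 of MyNet give 8 * 8 * 64 features.\n",
    "print(flattened_features(32, 32))\n",
    "```\n",
    "\n",
    "For MNIST, `fc1` therefore receives 3136 input features per sample."
   ]
  },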
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T09:29:10.271414Z",
     "start_time": "2022-03-31T09:29:10.269302Z"
    },
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## <span id=\"trainer\">3. Create a trainer</span>\n",
    "\n",
    "FederatedScope decouples the local learning process and details of FL communication and schedule, allowing users to freely customize the local learning algorithms via the `Trainer`. We recommend user build trainer by inheriting `federatedscope.core.trainers.trainer.GeneralTorchTrainer`, for more details, please see [Trainer](https://federatedscope.io/docs/trainer/). Similarly, we provide `federatedscope.register.register_trainer` to make your customized trainer available:\n",
    "\n",
    "* Step1: build your trainer by inheriting `GeneralTrainer`. Our `GeneralTrainer` already supports many different usages, for the advanced user, please see [federatedscope.core.trainers.trainer.GeneralTrainer]() for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T10:13:13.181147Z",
     "start_time": "2022-03-31T10:13:12.792631Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from federatedscope.core.trainers import GeneralTorchTrainer\n",
    "\n",
    "class MyTrainer(GeneralTorchTrainer):\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "* Step2: register your trainer with a keyword, such as `\"mytrainer\"`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T10:13:13.185648Z",
     "start_time": "2022-03-31T10:13:13.182604Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from federatedscope.register import register_trainer\n",
    "\n",
    "def call_my_trainer(trainer_type):\n",
    "    if trainer_type == 'mycvtrainer':\n",
    "        trainer_builder = MyTrainer\n",
    "        return trainer_builder\n",
    "\n",
    "register_trainer('mycvtrainer', call_my_trainer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## <span id=\"metric\">4. Introduce more evaluation metrics</span>\n",
    "We provide a number of metrics to monitor the entire federal learning process. You just need to list the name of the metric you want in `cfg.eval.metrics`. We currently support metrics such as loss, accuracy, etc. (See [federatedscope.core.evaluator](federatedscope/core/evaluator.py) for more details).\n",
    "\n",
    "We also provide a function `federatedscope.register.register_metric` to make your evaluation metrics available with three steps:\n",
    "\n",
    "* Step1: build your metric (see [federatedscope.core.context](federatedscope/core/context.py) for more about `ctx`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T10:13:13.189195Z",
     "start_time": "2022-03-31T10:13:13.187033Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def cal_my_metric(ctx, **kwargs):\n",
    "    return ctx[\"num_train_data\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "* Step2: register your metric with a keyword, such as `\"mymetric\"`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T10:13:13.193453Z",
     "start_time": "2022-03-31T10:13:13.190519Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from federatedscope.register import register_metric\n",
    "\n",
    "def call_my_metric(types):\n",
    "    if \"mymetric\" in types:\n",
    "        metric_builder = cal_my_metric\n",
    "        return \"mymetric\", metric_builder\n",
    "\n",
    "register_metric(\"mymetric\", call_my_metric)"
   ]
  },
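  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All four `register_*` calls above follow the same pattern: a keyword is mapped to a builder function, and each builder implicitly returns `None` unless the requested type matches. The following is a simplified sketch of that dispatch idea, not FederatedScope's actual implementation. `toy_registry`, `toy_register`, `toy_build`, `ToyTrainer` and `call_toy_trainer` are all hypothetical names.\n",
    "\n",
    "```python\n",
    "toy_registry = {}\n",
    "\n",
    "\n",
    "def toy_register(key, builder):\n",
    "    # Map a keyword to the function that knows how to build it.\n",
    "    toy_registry[key] = builder\n",
    "\n",
    "\n",
    "def toy_build(requested_type):\n",
    "    # Ask every registered builder; the first non-None result wins.\n",
    "    for builder in toy_registry.values():\n",
    "        result = builder(requested_type)\n",
    "        if result is not None:\n",
    "            return result\n",
    "    raise ValueError('no builder handles type: ' + requested_type)\n",
    "\n",
    "\n",
    "class ToyTrainer:\n",
    "    pass\n",
    "\n",
    "\n",
    "def call_toy_trainer(trainer_type):\n",
    "    # Falls through and returns None for any other type.\n",
    "    if trainer_type == 'toytrainer':\n",
    "        return ToyTrainer\n",
    "\n",
    "\n",
    "toy_register('toytrainer', call_toy_trainer)\n",
    "print(toy_build('toytrainer').__name__)\n",
    "```\n",
    "\n",
    "Under this sketch's assumptions, the type check inside each `call_my_*` function matters: a builder that returned a value unconditionally would answer for every requested type."
   ]
  },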
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Let's start!\n",
    "* Set your data, model, trainer and metric first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T10:13:13.219711Z",
     "start_time": "2022-03-31T10:13:13.195532Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from federatedscope.core.configs.config import global_cfg\n",
    "\n",
    "cfg = global_cfg.clone()\n",
    "\n",
    "cfg.data.type = 'mycvdata'\n",
    "cfg.data.root = 'data'\n",
    "cfg.data.transform = [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]\n",
    "cfg.model.type = 'mycnn'\n",
    "cfg.model.out_channels = 10\n",
    "cfg.trainer.type = 'mycvtrainer'\n",
    "cfg.eval.metric = ['mymetric']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "* Configure other options in `cfg`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-31T10:13:13.225301Z",
     "start_time": "2022-03-31T10:13:13.221148Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "cfg.use_gpu = False\n",
    "cfg.best_res_update_round_wise_key = \"test_loss\"\n",
    "\n",
    "cfg.federate.mode = 'standalone'\n",
    "cfg.federate.local_update_steps = 5\n",
    "cfg.federate.total_round_num = 20\n",
    "cfg.federate.sample_client_num = 5\n",
    "cfg.federate.client_num = 5\n",
    "\n",
    "cfg.train.optimizer.lr = 0.001\n",
    "cfg.train.optimizer.weight_decay = 0.0\n",
    "cfg.grad.grad_clip = 5.0\n",
    "\n",
    "cfg.criterion.type = 'CrossEntropyLoss'\n",
    "cfg.seed = 123\n",
    "cfg.eval.best_res_update_round_wise_key = \"test_loss\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "* Start your FL process!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2022-03-31T10:13:12.142Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from federatedscope.core.auxiliaries.data_builder import get_data\n",
    "from federatedscope.core.auxiliaries.utils import setup_seed\n",
    "from federatedscope.core.auxiliaries.logging import update_logger\n",
    "from federatedscope.core.fed_runner import FedRunner\n",
    "from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls\n",
    "\n",
    "setup_seed(cfg.seed)\n",
    "update_logger(cfg)\n",
    "data, modified_cfg = get_data(cfg)\n",
    "cfg.merge_from_other_cfg(modified_cfg)\n",
    "Fed_runner = FedRunner(data=data,\n",
    "                       server_class=get_server_cls(cfg),\n",
    "                       client_class=get_client_cls(cfg),\n",
    "                       config=cfg.clone())\n",
    "Fed_runner.run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
